{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ce528bb2",
   "metadata": {},
   "source": [
    "# Diffusion Sandbox\n",
    "\n",
    "In this notebook, the intricacies of a denosing diffusion framework are illustrated with the aid of simple snippets.\n",
    "\n",
    "First, let's import an image to use for the examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec7d6dc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import imageio\n",
    "\n",
    "mpl.rcParams['figure.figsize'] = (12, 8)\n",
    "\n",
    "img = torch.FloatTensor(imageio.imread('imgs/hills_2.png')/255)\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e932b7cd",
   "metadata": {},
   "source": [
    "### Before we start...\n",
    "The majority of the diffusion models assume that the images are scaled to the `[-1,+1]` range (which tends to simplify many equations). This tutorial will follow the same approach, so we need to define input and output transformation functions `input_T()` and `output_T()`.\n",
    "\n",
    "Also, let's define our own `show()` wrapper function that displays the image with automatic output transformation!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f2a5c35",
   "metadata": {},
   "outputs": [],
   "source": [
    "def input_T(input):\n",
    "    # [0,1] -> [-1,+1]\n",
    "    return 2*input-1\n",
    "    \n",
    "def output_T(input):\n",
    "    # [-1,+1] -> [0,1]\n",
    "    return (input+1)/2\n",
    "\n",
    "def show(input):\n",
    "    plt.imshow(output_T(input).clip(0,1))\n",
    "    \n",
    "img_=input_T(img)\n",
    "show(img_)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "823147d9",
   "metadata": {},
   "source": [
    "### Defining a schedule\n",
    "The diffusion process is built based on a variance schedule, which determines the levels of added noise at each step of the process. To that end, our schedule is defined below, with the following quantities:\n",
    "\n",
    "* `betas`:$\\beta_t$ \n",
    "\n",
    "\n",
    "* `alphas`: $\\alpha_t=1-\\beta_t$\n",
    "\n",
    "\n",
    "* `alphas_sqrt`:  $\\sqrt{\\alpha_t}$ \n",
    "\n",
    "\n",
    "* `alphas_prod`: $\\bar{\\alpha}_t=\\prod_{i=0}^{t}\\alpha_i$ \n",
    "\n",
    "\n",
    "* `alphas_prod_sqrt`: $\\sqrt{\\bar{\\alpha}_t}$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6494b544",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_timesteps=1000\n",
    "betas=torch.linspace(1e-4,2e-2,num_timesteps)\n",
    "\n",
    "alphas=1-betas\n",
    "alphas_sqrt=alphas.sqrt()\n",
    "alphas_cumprod=torch.cumprod(alphas,0)\n",
    "alphas_cumprod_sqrt=alphas_cumprod.sqrt()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4576f376",
   "metadata": {},
   "source": [
    "## Forward Process\n",
    "The forward process $q$ determines how subsequent steps in the diffusion are derived (gradual distortion of the original sample $x_0$).\n",
    "\n",
    "📃 First, let's bring up the key equations describing this process...\n",
    "\n",
    "Basic format of the forward step:\n",
    "$$q(x_t|x_{t−1}) := \\mathcal{N}(x_t; \\sqrt{1 − \\beta_t}x_{t−1}, \\beta_tI) \\tag{1}$$\n",
    "\n",
    "to step directly from $x_0$ to $x_t$:\n",
    "$$q(x_t|x_0) = \\mathcal{N}(x_t;\\sqrt{\\bar{\\alpha_t}}x_0, (1 − \\bar{\\alpha_t})I) \\tag{2}$$"
   ]
  },
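  {
   "cell_type": "markdown",
   "id": "a1f0c2b3",
   "metadata": {},
   "source": [
    "As a quick sanity check, composing $t$ single steps of equation (1) should produce the same marginal distribution as the direct jump of equation (2). The snippet below (a minimal sketch, using only the schedule defined above) iterates equation (1) and compares the empirical residual standard deviation against the closed form $\\sqrt{1-\\bar{\\alpha}_t}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2e1d3c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "t=200\n",
    "\n",
    "# iterate eq. (1): x_s = sqrt(1-beta_s)*x_{s-1} + sqrt(beta_s)*eps\n",
    "x=img_.clone()\n",
    "for s in range(t+1):\n",
    "    x=alphas_sqrt[s]*x+betas[s].sqrt()*torch.randn_like(x)\n",
    "\n",
    "# closed-form marginal parameters of q(x_t|x_0) from eq. (2)\n",
    "mean_cf=alphas_cumprod_sqrt[t]*img_\n",
    "std_cf=(1-alphas_cumprod[t]).sqrt()\n",
    "\n",
    "# with this many pixels, the per-pixel residual std should match the closed form\n",
    "print('empirical std:   {:.4f}'.format((x-mean_cf).std().item()))\n",
    "print('closed-form std: {:.4f}'.format(std_cf.item()))"
   ]
  },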
  {
   "cell_type": "markdown",
   "id": "ef306b72",
   "metadata": {},
   "source": [
    "### Let's define a function `forward_step()` that will allow us to use both $q(x_t|x_{t-1})$ and  `forward_jump()` for $q(x_t|x_0)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1976dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_step(t, condition_img, return_noise=False):\n",
    "    \"\"\"\n",
    "        forward step: t-1 -> t\n",
    "    \"\"\"    \n",
    "    assert t >= 0\n",
    "\n",
    "    mean=alphas_sqrt[t]*condition_img    \n",
    "    std=betas[t].sqrt()\n",
    "      \n",
    "    # sampling from N\n",
    "    if not return_noise:\n",
    "        return mean+std*torch.randn_like(img)\n",
    "    else:\n",
    "        noise=torch.randn_like(img)\n",
    "        return mean+std*noise, noise\n",
    "    \n",
    "def forward_jump(t, condition_img, condition_idx=0, return_noise=False):\n",
    "    \"\"\"\n",
    "        forward jump: 0 -> t\n",
    "    \"\"\"   \n",
    "    assert t >= 0\n",
    "    \n",
    "    mean=alphas_cumprod_sqrt[t]*condition_img\n",
    "    std=(1-alphas_cumprod[t]).sqrt()\n",
    "      \n",
    "    # sampling from N\n",
    "    if not return_noise:\n",
    "        return mean+std*torch.randn_like(img)\n",
    "    else:\n",
    "        noise=torch.randn_like(img)\n",
    "        return mean+std*noise, noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e481f663",
   "metadata": {},
   "outputs": [],
   "source": [
    "N=5 # number of computed states between x_0 and x_T\n",
    "M=4 # number of samples taken from each distribution"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fea926a",
   "metadata": {},
   "source": [
    "In the first example, when `t==0`, we want to derive a sample $x_t$ based on the clean sample $x_0$!\n",
    "\n",
    "The first column shows the mean image for a given stage of the diffusion, and the subsequent columns to the right show several samples taken from the same distribution (they are different if you look closely!)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f3dc1e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12,8))\n",
    "for idx in range(N):\n",
    "    t_step=int(idx*(num_timesteps/N))\n",
    "    \n",
    "    plt.subplot(N,1+M,1+(M+1)*idx)\n",
    "    show(alphas_cumprod_sqrt[t_step]*img_)\n",
    "    plt.title(r'$\\mu_t=\\sqrt{\\bar{\\alpha}_t}x_0$') if idx==0 else None\n",
    "    plt.ylabel(\"t: {:.2f}\".format(t_step/num_timesteps))\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    \n",
    "    for sample in range(M):\n",
    "        x_t=forward_jump(t_step,img_)\n",
    "        \n",
    "        plt.subplot(N,1+M,2+(1+M)*idx+sample)\n",
    "        show(x_t)        \n",
    "        plt.axis('off')\n",
    "        \n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2639a95d",
   "metadata": {},
   "source": [
    "Alternatively, we can test the process of going from $x_{t-1}$ to $x_t$, which is a single step in the diffusion process. For that we can use the `forward_step` function.\n",
    "\n",
    "Note that the mean $\\mu_t$ is now a bit different (first column) since it is conditioned on a specific sample of $x_{t-1}$!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "842e1016",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12,8))\n",
    "for idx in range(N):\n",
    "    t_step=int(idx*(num_timesteps/N))\n",
    "    prev_img=forward_jump(max([0,t_step-1]),img_) # directly go to prev state\n",
    "    \n",
    "    plt.subplot(N,1+M,1+(M+1)*idx)\n",
    "    show(alphas_sqrt[t_step]*prev_img)\n",
    "    plt.title(r'$\\mu_t=\\sqrt{1-\\beta_t}x_{t-1}$') if idx==0 else None\n",
    "    plt.ylabel(\"t: {:.2f}\".format(t_step/num_timesteps))\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    \n",
    "    for sample in range(M):\n",
    "        plt.subplot(N,1+M,2+(1+M)*idx+sample)\n",
    "        x_t=forward_step(t_step,prev_img)\n",
    "        show(x_t)        \n",
    "        plt.axis('off')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa53f91c",
   "metadata": {},
   "source": [
    "## Reverse Process\n",
    "\n",
    "The purpose of the reverse process $p$ is to approximate the previous step $x_{t-1}$ in the diffusion chain based on a sample $x_t$. In practice, this approximation $p(x_{t-1}|x_t)$ must be done without the knowledge of $x_0$.\n",
    "\n",
    "A parametrizable prediction model with parameters $\\theta$ is used to estimate $p_\\theta(x_{t-1}|x_t)$.\n",
    "\n",
    "The reverse process will also be (approximately) gaussian if the diffusion steps are small enough:\n",
    "\n",
    "$$p_\\theta(x_{t-1}|x_t) := \\mathcal{N}(x_{t-1};\\mu_\\theta(x_t),\\Sigma_\\theta(x_t))\\tag{3}$$\n",
    "\n",
    "In many works, it is assumed that the variance of this distribution should not depend strongly on $x_0$ or $x_t$, but rather on the stage of the diffusion process $t$. This can be observed in the true distribution $q(x_{t-1}|x_t, x_0)$, where the variance of the distribution equals $\\tilde{\\beta}_t$."
   ]
  },
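  {
   "cell_type": "markdown",
   "id": "c3f2e4d5",
   "metadata": {},
   "source": [
    "To see why a schedule-only variance is reasonable, we can plot $\\tilde{\\beta}_t$ against $\\beta_t$ (a small sketch using the linear schedule defined above): except for the first few steps, the two curves are nearly identical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a3f5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# beta_tilde_t = (1-alpha_bar_{t-1})/(1-alpha_bar_t)*beta_t, with alpha_bar_{-1} := 1\n",
    "alphas_cumprod_prev=F.pad(alphas_cumprod[:-1],(1,0),value=1.0)\n",
    "betas_tilde=(1-alphas_cumprod_prev)/(1-alphas_cumprod)*betas\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "plt.plot(betas,label=r'$\\beta_t$')\n",
    "plt.plot(betas_tilde,'--',label=r'$\\tilde{\\beta}_t$')\n",
    "plt.xlabel(r'step $t$')\n",
    "plt.legend()"
   ]
  },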
  {
   "cell_type": "markdown",
   "id": "b03ecb3d",
   "metadata": {},
   "source": [
    "### Parameterizing $\\mu_\\theta$\n",
    "There are at least 3 ways of parameterizing the mean of the reverse step distribution $p_\\theta(x_{t-1}|x_t)$:\n",
    "* Directly (a neural network will estimate $\\mu_\\theta$)\n",
    "* Via $x_0$ (a neural network will estimate $x_0$)\n",
    "$$\\tilde{\\mu}_\\theta = \\frac{\\sqrt{\\bar{\\alpha}_{t-1}}\\beta_t}{1-\\bar{\\alpha}_t}x_0 + \\frac{\\sqrt{\\alpha_t}(1-\\bar{\\alpha}_{t-1})}{1-\\bar{\\alpha}_t}x_t\\tag{4}$$\n",
    "* Via noise $\\epsilon$ subtraction from $x_0$ (a neural network will estimate $\\epsilon$)\n",
    "$$x_0=\\frac{1}{\\sqrt{\\bar{\\alpha}_t}}(x_t-\\sqrt{1-\\bar{\\alpha}_t}\\epsilon)\\tag{5}$$\n",
    "\n",
    "The approach of approximating the normal noise $\\epsilon$ is used most widely.\n",
    "\n",
    "Let's look at what an example $\\epsilon$ might look like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "596c106b",
   "metadata": {},
   "outputs": [],
   "source": [
    "t_step=50\n",
    "\n",
    "x_t,noise=forward_jump(t_step,img_,return_noise=True)\n",
    "\n",
    "plt.subplot(1,3,1)\n",
    "show(img_)\n",
    "plt.title(r'$x_0$')\n",
    "plt.axis('off')\n",
    "plt.subplot(1,3,2)\n",
    "show(x_t)\n",
    "plt.title(r'$x_t$')\n",
    "plt.axis('off')\n",
    "plt.subplot(1,3,3)\n",
    "show(noise)\n",
    "plt.title(r'$\\epsilon$')\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91c01ffb",
   "metadata": {},
   "source": [
    "If $\\epsilon$ is predicted correctly, we can use the equation (5) to predict $x_0$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d36a475",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*noise)/(alphas_cumprod_sqrt[t_step])\n",
    "\n",
    "plt.subplot(1,3,1)\n",
    "show(x_t)\n",
    "plt.title('$x_t$ ($\\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))\n",
    "plt.axis('off')\n",
    "plt.subplot(1,3,2)\n",
    "show(x_0_pred)\n",
    "plt.title('$x_0$ prediction ($\\ell_1$: {:.3f})'.format(F.l1_loss(x_0_pred,img_)))\n",
    "plt.axis('off') \n",
    "plt.subplot(1,3,3)\n",
    "show(img_)\n",
    "plt.title('$x_0$')\n",
    "plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8841689c",
   "metadata": {},
   "source": [
    "Approximation (or knowledge) of $x_0$ allows us to approximate the mean of the step $t-1$, using (4)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f37fbbbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimate mean\n",
    "mean_pred=x_0_pred*(alphas_cumprod_sqrt[t_step-1]*betas[t_step])/(1-alphas_cumprod[t_step]) + x_t*(alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1]))/(1-alphas_cumprod[t_step])\n",
    "\n",
    "# let's compare it to ground truth mean of the previous step (requires knowledge of x_0)\n",
    "mean_gt=alphas_cumprod_sqrt[t_step-1]*img_"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ffd11d7",
   "metadata": {},
   "source": [
    "Since reverse process mean estimation $\\tilde{\\mu}_\\theta$ in (4) is effectively linear interpolation between noisy $x_t$ and $x_0$ it is expected to have a higher error (as the additive noise is still present) compared to the mean computed using the forward process (which is computed by scaling the clean sample by a scalar value)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7593d85c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.subplot(1,3,1)\n",
    "show(x_t)\n",
    "plt.title('$x_t$   ($\\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))\n",
    "plt.subplot(1,3,2)\n",
    "show(mean_pred)\n",
    "plt.title(r'$\\tilde{\\mu}_{t-1}$' + '  ($\\ell_1$: {:.3f})'.format(F.l1_loss(mean_pred,img_)))\n",
    "plt.subplot(1,3,3)\n",
    "show(mean_gt)\n",
    "plt.title(r'$\\mu_{t-1}$' + '  ($\\ell_1$: {:.3f})'.format(F.l1_loss(mean_gt,img_)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d30f0d7",
   "metadata": {},
   "source": [
    "Once we get our `mean_pred` ($\\tilde{\\mu_{t}}$), we can define our distribution for the previous step\n",
    "\n",
    "$$\\tilde{\\beta}_t=\\beta_t \\tag{6}$$\n",
    "\n",
    "$$ p_\\theta(x_{t-1}|x_t) := \\mathcal{N}(x_{t-1};\\tilde{\\mu}_\\theta(x_t,t),\\tilde{\\beta}_t I) \\tag{7}$$\n",
    "\n",
    "> Important: the experiment below should be treated as a simulation. In practice, the network must  predict either $\\epsilon$ or $x_0$ or $\\tilde{\\mu}_\\theta$. Here, the value of $epsilon$ is simply subs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4431810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reverse_step(epsilon, x_t, t_step, return_noise=False):\n",
    "    \n",
    "    # estimate x_0 based on epsilon\n",
    "    x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*epsilon)/(alphas_cumprod_sqrt[t_step])\n",
    "    if t_step==0:\n",
    "        sample=x_0_pred\n",
    "        noise=torch.zeros_like(x_0_pred)\n",
    "    else:\n",
    "        # estimate mean\n",
    "        mean_pred=x_0_pred*(alphas_cumprod_sqrt[t_step-1]*betas[t_step])/(1-alphas_cumprod[t_step]) + x_t*(alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1]))/(1-alphas_cumprod[t_step])\n",
    "\n",
    "        # compute variance\n",
    "        beta_pred=betas[t_step].sqrt() if t_step != 0 else 0\n",
    "\n",
    "        sample=mean_pred+beta_pred*torch.randn_like(x_t)\n",
    "        # this noise is only computed for simulation purposes (since x_0_pred is not known normally)\n",
    "        noise=(sample-x_0_pred*alphas_cumprod_sqrt[t_step-1])/(1-alphas_cumprod[t_step-1]).sqrt()\n",
    "    if return_noise:\n",
    "        return sample, noise\n",
    "    else:\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6cf5f8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_t,noise=forward_jump(1000-1,img_,return_noise=True)\n",
    "\n",
    "state_imgs=[x_t.numpy()]\n",
    "for t_step in reversed(range(1000)):\n",
    "    x_t,noise=reverse_step(noise,x_t,t_step,return_noise=True)\n",
    "    \n",
    "    if t_step % 200 == 0:\n",
    "        state_imgs.append(x_t.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26eb9a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "for idx,state_img in enumerate(state_imgs):\n",
    "    plt.subplot(1,len(state_imgs),idx+1)\n",
    "    show(state_img.clip(-1,1))\n",
    "    plt.axis('off')\n",
    "    \n",
    "plt.tight_layout()"
   ]
  },
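  {
   "cell_type": "markdown",
   "id": "e5b4a6f7",
   "metadata": {},
   "source": [
    "As a quick quantitative check, we can measure how close the final state of the simulated reverse chain is to the clean image. Since the true forward noise was substituted for $\\epsilon$ at every step, the reconstruction error should be close to zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c5b7a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# l1 distance between the final reverse-chain state and the clean image x_0\n",
    "print('final l1: {:.4f}'.format(F.l1_loss(x_t,img_).item()))"
   ]
  },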
  {
   "cell_type": "markdown",
   "id": "244d9302",
   "metadata": {},
   "source": [
    "## Packaging into Components\n",
    "The processes investigated above are neatly packaged into modular components for easier management of the diffusion framework.\n",
    "\n",
    "First, the forward process component `GaussianForwardProcess` encapsulates the functions of $q(x_t|x_0)$ and $q(x_t|x_{t-1})$.\n",
    "\n",
    "Below, we can see how different schedules of the variance parameter $\\beta$ affect how the noise level changes throughout the progression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc08885b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.DenoisingDiffusionProcess import *\n",
    "\n",
    "D=128\n",
    "make_white=False\n",
    "save=False\n",
    "line_color='black' #'#9EFFB9'\n",
    "test_img=img[256-D:256+D,512-D:512+D,:]\n",
    "\n",
    "for schedule in ['linear','quadratic','sigmoid','cosine']:\n",
    "    fw=GaussianForwardProcess(1000,\n",
    "                              schedule)\n",
    "\n",
    "    plt.figure(figsize=(10,2))\n",
    "    plt.subplot(1,6,1)    \n",
    "    plt.plot(fw.betas,color=line_color)\n",
    "    plt.title(schedule,color=line_color)\n",
    "    plt.xlabel(r'step $t$',color=line_color)\n",
    "    plt.ylabel(r'$\\beta_t$',color=line_color)\n",
    "    \n",
    "    if make_white:\n",
    "        plt.xticks(color='white')\n",
    "        plt.gca().tick_params(axis='x', colors='white')\n",
    "        plt.gca().tick_params(axis='y', colors='white')\n",
    "        plt.gca().spines['top'].set_color('white')\n",
    "        plt.gca().spines['right'].set_color('white')\n",
    "        plt.gca().spines['left'].set_color('white')\n",
    "        plt.gca().spines['bottom'].set_color('white')\n",
    "    for step in range(5):\n",
    "        plt.subplot(1,6,step+2)\n",
    "        plt.imshow(fw(test_img.permute(2,0,1).unsqueeze(0),torch.tensor(step*200))[0].permute(1,2,0).clip(0,1))\n",
    "        plt.axis('off')        \n",
    "    plt.tight_layout()\n",
    "    \n",
    "    \n",
    "    if save:\n",
    "        plt.savefig('{}.png'.format(schedule),\n",
    "                    dpi=200,\n",
    "                    bbox_inches='tight',\n",
    "                    pad_inches=0.0,\n",
    "                    transparent=True)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
